---
title: Winograd — Strassen algorithm
description: Consider a modification of Strassen's algorithm for square matrix multiplication with fewer summations between blocks than in the ordinary...
sections: [Multithreading,Block matrices,Comparing algorithms]
tags: [java,streams,arrays,multidimensional arrays,matrices,recursion,loops,nested loops]
canonical_url: /en/2022/02/11/winograd-strassen-algorithm.html
url_translated: /ru/2022/02/10/winograd-strassen-algorithm.html
title_translated: Алгоритм Винограда — Штрассена
date: 2022.02.11
lang: en
---

Consider a modification of Strassen's algorithm for square matrix multiplication with a *smaller*
number of summations between blocks than in the ordinary algorithm — 15 instead of 18 — and the
same number of multiplications as in the ordinary algorithm — 7. We will use Java Streams.

Recursive partitioning of matrices into blocks during multiplication makes sense only up to a certain
limit; beyond it the partitioning loses its point, since Strassen's algorithm does not make use of the
execution environment's cache. Therefore, for small blocks we use a parallel version of nested loops,
and for large blocks we perform the recursive partitioning in parallel.

We determine the boundary between the two algorithms experimentally, tuning it to the cache of the
execution environment. The benefit of Strassen's algorithm becomes more evident on sizable matrices —
the gap with the nested-loops algorithm grows and depends on the execution environment. Let's compare
the running times of the two algorithms.

*Algorithm using three nested loops: [Optimizing matrix multiplication]({{ '/en/2021/12/10/optimizing-matrix-multiplication.html' | relative_url }}).*

{% include heading.html text="Algorithm description" hash="algorithm-description" %}

The matrices must be the same size. We partition each matrix into 4 equally sized blocks. The blocks
must be square, so if this is not the case, we first pad the matrices with zero rows and columns and
only then partition them into blocks. We will remove the redundant rows and columns from the
resulting matrix later.

{% include image_svg.html src="/img/block-matrices.svg" style="width:221pt; height:33pt;"
alt="{\displaystyle A={\begin{pmatrix}A_{11}&A_{12}\\A_{21}&A_{22}\end{pmatrix}},\quad B={\begin{pmatrix}B_{11}&B_{12}\\B_{21}&B_{22}\end{pmatrix}}.}" %}
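For an odd size the midpoint is rounded up, so four `m×m` blocks cover the matrix after zero padding. A minimal sketch of the arithmetic (the class and method names are illustrative):

```java
// the rounded-up midpoint determines the block size
public class Midpoint {

    // block size for a matrix of size n
    static int blockSize(int n) {
        return n - n / 2; // integer division rounds down, so m rounds up
    }

    public static void main(String[] args) {
        int n = 5, m = blockSize(n);
        // block size and the number of zero rows and columns added
        System.out.println(m + " " + (2 * m - n)); // 3 1
    }
}
```

For `n=5` the blocks are `3×3`, and one zero row and one zero column of padding are added.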

Summation of blocks.

{% include image_svg.html src="/img/sums1.svg" style="width:101pt; height:148pt;"
alt="{\displaystyle{\begin{aligned}S_{1}&=(A_{21}+A_{22});\\S_{2}&=(S_{1}-A_{11});\\S_{3}&=(A_{11}-A_{21});\\S_{4}&=(A_{12}-S_{2});\\S_{5}&=(B_{12}-B_{11});\\S_{6}&=(B_{22}-S_{5});\\S_{7}&=(B_{22}-B_{12});\\S_{8}&=(S_{6}-B_{21}).\end{aligned}}}" %}

Multiplication of blocks.

{% include image_svg.html src="/img/products.svg" style="width:75pt; height:127pt;"
alt="{\displaystyle{\begin{aligned}P_{1}&=S_{2}S_{6};\\P_{2}&=A_{11}B_{11};\\P_{3}&=A_{12}B_{21};\\P_{4}&=S_{3}S_{7};\\P_{5}&=S_{1}S_{5};\\P_{6}&=S_{4}B_{22};\\P_{7}&=A_{22}S_{8}.\end{aligned}}}" %}

Summation of blocks.

{% include image_svg.html src="/img/sums2.svg" style="width:78pt; height:31pt;"
alt="{\displaystyle{\begin{aligned}T_{1}&=P_{1}+P_{2};\\T_{2}&=T_{1}+P_{4}.\end{aligned}}}" %}

Blocks of the resulting matrix.

{% include image_svg.html src="/img/sums3.svg" style="width:240pt; height:33pt;"
alt="{\displaystyle{\begin{pmatrix}C_{11}&C_{12}\\C_{21}&C_{22}\end{pmatrix}}={\begin{pmatrix}P_{2}+P_{3}&T_{1}+P_{5}+P_{6}\\T_{2}-P_{7}&T_{2}+P_{5}\end{pmatrix}}.}" %}
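For blocks of size `1×1` the formulas above can be checked directly on numbers — multiplying two `2×2` matrices with 7 multiplications and 15 summations. A minimal sketch (the class and method names are illustrative):

```java
import java.util.Arrays;

// numeric check of the Winograd — Strassen formulas for 1×1 blocks
public class WinogradCheck {

    // returns {c11, c12, c21, c22} computed by the formulas above
    static int[] multiply(int a11, int a12, int a21, int a22,
                          int b11, int b12, int b21, int b22) {
        // 8 preliminary sums
        int s1 = a21 + a22, s2 = s1 - a11, s3 = a11 - a21, s4 = a12 - s2;
        int s5 = b12 - b11, s6 = b22 - s5, s7 = b22 - b12, s8 = s6 - b21;
        // 7 products
        int p1 = s2 * s6, p2 = a11 * b11, p3 = a12 * b21, p4 = s3 * s7;
        int p5 = s1 * s5, p6 = s4 * b22, p7 = a22 * s8;
        // 2 intermediate sums
        int t1 = p1 + p2, t2 = t1 + p4;
        // 5 final sums give the blocks of the result
        return new int[]{p2 + p3, t1 + p5 + p6, t2 - p7, t2 + p5};
    }

    public static void main(String[] args) {
        // (1 2; 3 4) × (5 6; 7 8) = (19 22; 43 50)
        System.out.println(Arrays.toString(multiply(1, 2, 3, 4, 5, 6, 7, 8)));
    }
}
```

The result matches the direct multiplication, and the summations add up: 8 preliminary, 2 intermediate, and 5 final — 15 in total.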

{% include heading.html text="Hybrid algorithm" hash="hybrid-algorithm" %}

We partition each of the matrices `A` and `B` into 4 equally sized blocks, padding the missing parts
with zeros if necessary. We perform 15 summations and 7 multiplications over the blocks and obtain
4 blocks of the matrix `C`. We remove the redundant zeros, if any were added, and return the resulting
matrix. We run the recursive partitioning of large blocks in parallel, and for small blocks we call
the algorithm with nested loops.

```java
/**
 * @param n   matrix size
 * @param brd minimum matrix size
 * @param a   first matrix 'n×n'
 * @param b   second matrix 'n×n'
 * @return resulting matrix 'n×n'
 */
public static int[][] multiplyMatrices(int n, int brd, int[][] a, int[][] b) {
    // multiply small blocks using algorithm with nested loops
    if (n < brd) return simpleMultiplication(n, a, b);
    // midpoint of the matrix, round up — blocks should
    // be square, if necessary add zero rows and columns
    int m = n - n / 2;
    // blocks of the first matrix
    int[][] a11 = getQuadrant(m, n, a, true, true);
    int[][] a12 = getQuadrant(m, n, a, true, false);
    int[][] a21 = getQuadrant(m, n, a, false, true);
    int[][] a22 = getQuadrant(m, n, a, false, false);
    // blocks of the second matrix
    int[][] b11 = getQuadrant(m, n, b, true, true);
    int[][] b12 = getQuadrant(m, n, b, true, false);
    int[][] b21 = getQuadrant(m, n, b, false, true);
    int[][] b22 = getQuadrant(m, n, b, false, false);
    // summation of blocks
    int[][] s1 = sumMatrices(m, a21, a22, true);
    int[][] s2 = sumMatrices(m, s1, a11, false);
    int[][] s3 = sumMatrices(m, a11, a21, false);
    int[][] s4 = sumMatrices(m, a12, s2, false);
    int[][] s5 = sumMatrices(m, b12, b11, false);
    int[][] s6 = sumMatrices(m, b22, s5, false);
    int[][] s7 = sumMatrices(m, b22, b12, false);
    int[][] s8 = sumMatrices(m, s6, b21, false);
    int[][][] p = new int[7][][];
    // multiplication of blocks in parallel streams
    IntStream.range(0, 7).parallel().forEach(i -> {
        switch (i) { // recursive calls
            case 0 -> p[i] = multiplyMatrices(m, brd, s2, s6);
            case 1 -> p[i] = multiplyMatrices(m, brd, a11, b11);
            case 2 -> p[i] = multiplyMatrices(m, brd, a12, b21);
            case 3 -> p[i] = multiplyMatrices(m, brd, s3, s7);
            case 4 -> p[i] = multiplyMatrices(m, brd, s1, s5);
            case 5 -> p[i] = multiplyMatrices(m, brd, s4, b22);
            case 6 -> p[i] = multiplyMatrices(m, brd, a22, s8);
        }
    });
    // summation of blocks
    int[][] t1 = sumMatrices(m, p[0], p[1], true);
    int[][] t2 = sumMatrices(m, t1, p[3], true);
    // blocks of the resulting matrix
    int[][] c11 = sumMatrices(m, p[1], p[2], true);
    int[][] c12 = sumMatrices(m, t1, sumMatrices(m, p[4], p[5], true), true);
    int[][] c21 = sumMatrices(m, t2, p[6], false);
    int[][] c22 = sumMatrices(m, t2, p[4], true);
    // assemble a matrix from blocks,
    // remove zero rows and columns, if added
    return putQuadrants(m, n, c11, c12, c21, c22);
}
```

{% capture collapsed_md %}
```java
// helper method for matrix summation
private static int[][] sumMatrices(int n, int[][] a, int[][] b, boolean sign) {
    int[][] c = new int[n][n];
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            c[i][j] = sign ? a[i][j] + b[i][j] : a[i][j] - b[i][j];
    return c;
}
```
```java
// helper method, gets a block of a matrix
private static int[][] getQuadrant(int m, int n, int[][] x,
                                   boolean first, boolean second) {
    int[][] q = new int[m][m];
    if (first) for (int i = 0; i < m; i++)
        if (second) System.arraycopy(x[i], 0, q[i], 0, m); // x11
        else System.arraycopy(x[i], m, q[i], 0, n - m); // x12
    else for (int i = m; i < n; i++)
        if (second) System.arraycopy(x[i], 0, q[i - m], 0, m); // x21
        else System.arraycopy(x[i], m, q[i - m], 0, n - m); // x22
    return q;
}
```
```java
// helper method, assembles a matrix from blocks
private static int[][] putQuadrants(int m, int n,
                                    int[][] x11, int[][] x12,
                                    int[][] x21, int[][] x22) {
    int[][] x = new int[n][n];
    for (int i = 0; i < n; i++)
        if (i < m) {
            System.arraycopy(x11[i], 0, x[i], 0, m);
            System.arraycopy(x12[i], 0, x[i], m, n - m);
        } else {
            System.arraycopy(x21[i - m], 0, x[i], 0, m);
            System.arraycopy(x22[i - m], 0, x[i], m, n - m);
        }
    return x;
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Helper methods" content=collapsed_md -%}

{% include heading.html text="Nested loops" hash="nested-loops" %}

To complement the previous algorithm and to compare against it, we take the *optimized* variant of
nested loops, which makes better use of the execution environment's cache than the alternatives — the
rows of the resulting matrix are processed independently of each other in parallel streams. We use this
algorithm directly for small matrices; large matrices we partition into small blocks and apply the same
algorithm to them.

```java
/**
 * @param n matrix size
 * @param a first matrix 'n×n'
 * @param b second matrix 'n×n'
 * @return resulting matrix 'n×n'
 */
public static int[][] simpleMultiplication(int n, int[][] a, int[][] b) {
    // the resulting matrix
    int[][] c = new int[n][n];
    // bypass the rows of matrix 'a' in parallel mode
    IntStream.range(0, n).parallel().forEach(i -> {
        // bypass the indexes of the common side of two matrices:
        // the columns of matrix 'a' and the rows of matrix 'b'
        for (int k = 0; k < n; k++)
            // bypass the indexes of the columns of matrix 'b'
            for (int j = 0; j < n; j++)
                // the sum of the products of the elements of the i-th
                // row of matrix 'a' and the j-th column of matrix 'b'
                c[i][j] += a[i][k] * b[k][j];
    });
    return c;
}
```

{% include heading.html text="Testing" hash="testing" %}

For verification, we take two square matrices `A=[1000×1000]` and `B=[1000×1000]` filled with random
numbers, and set the minimum block size to `[200×200]` elements. First we check the correctness of the
two implementations — the matrix products must match. Then we execute each method 10 times and compute
the average execution time.

```java
// start the program and output the result
public static void main(String[] args) {
    // incoming data
    int n = 1000, brd = 200, steps = 10;
    int[][] a = randomMatrix(n, n), b = randomMatrix(n, n);
    // matrix products
    int[][] c1 = multiplyMatrices(n, brd, a, b);
    int[][] c2 = simpleMultiplication(n, a, b);
    // check the correctness of the results
    System.out.println("The results match: " + Arrays.deepEquals(c1, c2));
    // measure the execution time of two methods
    benchmark("Hybrid algorithm", steps, () -> {
        int[][] c = multiplyMatrices(n, brd, a, b);
        if (!Arrays.deepEquals(c, c1)) System.out.print("error");
    });
    benchmark("Nested loops    ", steps, () -> {
        int[][] c = simpleMultiplication(n, a, b);
        if (!Arrays.deepEquals(c, c2)) System.out.print("error");
    });
}
```

{% capture collapsed_md %}
```java
// helper method, returns a matrix of the specified size
private static int[][] randomMatrix(int row, int col) {
    int[][] matrix = new int[row][col];
    for (int i = 0; i < row; i++)
        for (int j = 0; j < col; j++)
            // keep the values small so that the products
            // summed over a row do not overflow 'int'
            matrix[i][j] = (int) (Math.random() * 100);
    return matrix;
}
```
```java
// helper method for measuring the execution time of the passed code
private static void benchmark(String title, int steps, Runnable runnable) {
    long time, avg = 0;
    System.out.print(title);
    for (int i = 0; i < steps; i++) {
        time = System.currentTimeMillis();
        runnable.run();
        time = System.currentTimeMillis() - time;
        // execution time of one step
        System.out.print(" | " + time);
        avg += time;
    }
    // average execution time
    System.out.println(" || " + (avg / steps));
}
```
{% endcapture %}
{%- include collapsed_block.html summary="Helper methods" content=collapsed_md -%}

Output depends on the execution environment, time in milliseconds:

```
The results match: true
Hybrid algorithm | 196 | 177 | 156 | 205 | 154 | 165 | 133 | 118 | 132 | 134 || 157
Nested loops     | 165 | 164 | 168 | 167 | 168 | 168 | 170 | 179 | 173 | 168 || 169
```

{% include heading.html text="Comparing algorithms" hash="comparing-algorithms" %}

On an eight-core Linux x64 machine, we execute the test above 100 times instead of 10, keeping the
minimum block size `[brd=200]` elements. We change only `n` — the size of both matrices `A=[n×n]` and
`B=[n×n]` — and obtain a summary table of results. Time is in milliseconds.

```
               n | 900 | 1000 | 1100 | 1200 | 1300 | 1400 | 1500 | 1600 | 1700 |
-----------------|-----|------|------|------|------|------|------|------|------|
Hybrid algorithm |  96 |  125 |  169 |  204 |  260 |  313 |  384 |  482 |  581 |
Nested loops     | 119 |  162 |  235 |  281 |  361 |  497 |  651 |  793 |  971 |
```

Results: the benefit of Strassen's algorithm becomes more evident on large matrices, when the matrix
itself is several times larger than the minimum block, and it depends on the execution environment.
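The saving can be roughly estimated: each recursion level replaces 8 block products with 7, so at the leaf level the hybrid algorithm performs `7^d` instead of `8^d` small multiplications, where `d` is the recursion depth down to the minimum block size. A sketch of the estimate (the names are illustrative; this is an operation count, not a measurement):

```java
// estimate of leaf-level block multiplications saved by the recursion
public class LeafCount {

    // recursion depth until blocks become smaller than brd
    static int depth(int n, int brd) {
        int d = 0;
        while (n >= brd) { n = n - n / 2; d++; }
        return d;
    }

    public static void main(String[] args) {
        int d = depth(1600, 200);
        // leaf multiplications: hybrid algorithm vs plain recursion
        System.out.println(d + " levels: "
                + (long) Math.pow(7, d) + " vs " + (long) Math.pow(8, d));
    }
}
```

For `n=1600` and `brd=200` the recursion goes 4 levels deep, so 2401 small multiplications are performed instead of 4096 — consistent with the growing gap in the table above.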

All the methods described above, including those in the collapsed blocks, can be placed in one class.

{% capture collapsed_md %}
```java
import java.util.Arrays;
import java.util.stream.IntStream;
```
{% endcapture %}
{%- include collapsed_block.html summary="Required imports" content=collapsed_md -%}
